#!/usr/bin/env python3

from pathlib import Path
from dataclasses import dataclass
import shutil
import subprocess

# Apache-2.0 license header written as a C block comment.
# add_c_style_copyright() prepends this text verbatim to C-family and other
# block-comment sources; it already ends with a newline.
C_COPYRIGHT = """/*
 * Copyright (c) 2025 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
"""

# Apache-2.0 license header written with '#' line comments.
# Used by add_sh_style_copyright() for hash-commented files and by main() as
# the prefix of the generated .clang-format file; it ends with a newline.
SH_COPYRIGHT = """# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""


# clang-format configuration (YAML) emitted by main() into
# <output_dir>/.clang-format, prefixed with SH_COPYRIGHT, and then handed to
# clang-format by format_code() through "-style=file:<path>".
# The text holds a JavaScript section and a Cpp section separated by "---".
# It is a raw string so the backslashes inside the include regexes stay literal.
CLANG_FORMAT_STYLE = r"""BasedOnStyle: Google
---
Language: JavaScript
JavaScriptQuotes: Double
IndentWidth: 4

---
Language:        Cpp
BasedOnStyle:  Google
AccessModifierOffset: -4
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlines: Left
AlignOperands:   true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: true
BinPackParameters: true
BraceWrapping:   
  AfterClass:      false
  AfterControlStatement: false
  AfterEnum:       false
  AfterFunction:   true
  AfterNamespace:  false
  AfterObjCDeclaration: false
  AfterStruct:     false
  AfterUnion:      false
  AfterExternBlock: false
  BeforeCatch:     false
  BeforeElse:      false
  IndentBraces:    false
  SplitEmptyFunction: true
  SplitEmptyRecord: true
  SplitEmptyNamespace: true
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeInheritanceComma: false
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakConstructorInitializers: BeforeColon
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit:     120
CommentPragmas:  '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat:   false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros:   
  - foreach
  - Q_FOREACH
  - BOOST_FOREACH
IncludeBlocks:   Regroup
IncludeCategories: 
  - Regex:           '^<ext/.*\.h>'
    Priority:        2
  - Regex:           '^<.*\.h>'
    Priority:        1
  - Regex:           '^<.*'
    Priority:        2
  - Regex:           '.*'
    Priority:        3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth:     4
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd:   ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 4
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Right
ReflowComments:  true
SortIncludes:    false
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles:  false
SpacesInContainerLiterals: false
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard:        Cpp11
TabWidth:        4
UseTab:          Never
"""


def check_should_ignore(file: str) -> bool:
    """Tell whether a repository-relative path must be left out of the export.

    Args:
        file: Path relative to the repository root, using "/" separators
            (the form produced by ``git ls-tree``).

    Returns:
        True when the path is the repository's own ``.clang-format``, lives
        under a ``.githooks*``/``.github*`` prefix, under ``cookbook/`` or
        ``docs/``, under ``test/`` without being ``test/ani_*`` or
        ``test/CMakeLists.txt``, or has the same file name as this script.
        False otherwise.
    """
    if file == ".clang-format":
        return True

    if file.startswith((".githooks", ".github", "cookbook/", "docs/")):
        return True

    is_kept_test = file.startswith("test/ani_") or file == "test/CMakeLists.txt"
    if file.startswith("test/") and not is_kept_test:
        return True

    # Exclude this export script itself. Compare complete file names: a suffix
    # match would also drop unrelated files such as "my_<script>.py", and
    # Path(...).name copes with either path separator in __file__.
    script_name = Path(__file__).name
    return file.rsplit("/", 1)[-1] == script_name


def format_code(text: str, clang_format_path: Path, assume_filename: str) -> str:
    """Run ``clang-format`` over ``text`` and return the formatted source.

    The source is piped through stdin; nothing is written to disk.

    Args:
        text: Source code to format.
        clang_format_path: Style file passed as ``-style=file:<path>``.
        assume_filename: Pseudo file name passed as ``--assume-filename`` so
            clang-format can pick the language (for example ``text.cpp``).

    Returns:
        The text clang-format printed on stdout.

    Raises:
        RuntimeError: clang-format exited with a non-zero status; the message
            carries its stderr output.
    """
    process = subprocess.run(
        [
            "clang-format",
            f"-style=file:{clang_format_path}",
            f"--assume-filename={assume_filename}",
        ],
        input=text,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        # Sources are read and written as UTF-8 elsewhere in this script, so
        # the pipe must not fall back to the locale's encoding.
        encoding="utf-8",
        # check=False on purpose: with check=True a CalledProcessError is
        # raised inside run() and the stderr-carrying error below is never
        # reached.
        check=False,
    )

    if process.returncode != 0:
        raise RuntimeError(f"clang-format error: {process.stderr}")

    return process.stdout


def py_move_ellipsis(text: str) -> str:
    """Move a trailing ``: ...`` body onto its own indented line.

    A line such as ``def f(): ...`` becomes ``def f():`` followed by a line
    holding ``...`` indented four columns deeper than the original line. All
    other lines are kept as they are. Lines are re-joined with ``"\\n"``, so
    the result carries no trailing newline.
    """
    rewritten = []
    for raw in text.splitlines():
        trimmed = raw.rstrip()
        if not trimmed.endswith(": ..."):
            rewritten.append(raw)
            continue
        # Width of the leading whitespace of the original line.
        pad = len(raw) - len(raw.lstrip())
        # Dropping the last four characters (" ...") keeps the colon.
        rewritten.extend((trimmed[:-4], " " * pad + "    ..."))
    return "\n".join(rewritten)


def py_remove_asserts(text: str) -> str:
    """Replace each ``assert`` statement in Python source with ``pass``.

    Statements are located with the ``ast`` module rather than by comparing
    indentation, so that:

    * consecutive asserts are all replaced,
    * an assert that ends a block does not swallow the code after it,
    * multi-line asserts are removed whatever their closing-line layout,
    * text inside string literals that merely starts with "assert" is kept.

    Only asserts that begin their physical line are rewritten; every line the
    statement spans is replaced by one ``pass`` carrying the original leading
    whitespace. An assert placed after other code on the same line
    (``if x: assert y``) is left alone.

    Args:
        text: Python source code.

    Returns:
        The rewritten source joined with ``"\\n"`` and without a trailing
        newline. If ``text`` cannot be parsed, its lines are returned
        unchanged.
    """
    import ast  # local import: only this function needs it

    # Split exactly the way the Python parser counts lines so that AST line
    # numbers index straight into this list.
    lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    if lines and lines[-1] == "":
        lines.pop()
    source = "\n".join(lines)

    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        # Unparseable input: leave it alone rather than guess at structure.
        return source

    # first line index -> last line index (0-based, inclusive) per statement
    spans = {}
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assert):
            continue
        first = node.lineno - 1
        last = (node.end_lineno or node.lineno) - 1
        if lines[first].lstrip().startswith("assert"):
            spans[first] = max(spans.get(first, first), last)

    result = []
    skip_through = -1
    for index, line in enumerate(lines):
        if index <= skip_through:
            continue  # continuation line of an assert already replaced
        if index in spans:
            leading = line[: len(line) - len(line.lstrip())]
            result.append(leading + "pass")
            skip_through = spans[index]
        else:
            result.append(line)
    return "\n".join(result)


def add_c_style_copyright(text: str) -> str:
    """Return ``text`` with the C block-comment license header in front."""
    return "".join((C_COPYRIGHT, text))


def add_sh_style_copyright(text: str) -> str:
    """Insert the '#'-commented license header into ``text``.

    The header goes first, or directly after a leading ``#!`` line when there
    is one. Every piece is written back followed by a newline, so the header
    (which itself ends with a newline) is followed by one empty line.
    """
    body = text.splitlines()
    # Position 1 keeps an existing shebang as the very first line.
    insert_at = 1 if body and body[0].startswith("#!") else 0
    pieces = body[:insert_at] + [SH_COPYRIGHT] + body[insert_at:]
    return "".join(f"{piece}\n" for piece in pieces)


def add_python_encoding_header(text: str) -> str:
    """Add a ``# coding=utf-8`` declaration to Python source.

    The text is returned untouched when either of its first two lines already
    contains ``coding=``. Otherwise the two header lines are placed at the
    top, or right after the first line when that line is a ``#!`` shebang,
    and the result ends with a single newline.
    """
    source_lines = text.splitlines()

    # Do not add a second declaration when one is already present.
    for candidate in source_lines[:2]:
        if "coding=" in candidate:
            return text

    encoding_header = ["# coding=utf-8", "#"]

    # Keep the shebang on the first line.
    has_shebang = bool(source_lines) and source_lines[0].startswith("#!")
    if has_shebang:
        ordered = source_lines[:1] + encoding_header + source_lines[1:]
    else:
        ordered = encoding_header + source_lines
    return "\n".join(ordered) + "\n"


def add_shebang(text: str) -> str:
    """Ensure the script text begins with a shebang line.

    Text whose first line already starts with ``#!`` is returned unchanged.
    Otherwise ``#!/bin/bash`` is placed in front and the result ends with a
    single newline.
    """
    existing = text.splitlines()
    if existing and existing[0].startswith("#!"):
        return text
    return "\n".join(["#!/bin/bash", *existing]) + "\n"


def convert_to_include_guard(text: str, filename: str) -> str:
    """Replace ``#pragma once`` with a classic ``#ifndef`` include guard.

    Every line that is exactly ``#pragma once`` (ignoring surrounding
    whitespace) is removed. If at least one such line existed, the remaining
    text is wrapped in ``#ifndef``/``#define`` ... ``#endif``. Files whose
    name ends with ``.hpp`` or ``common.h`` are additionally wrapped in
    ``// NOLINTBEGIN`` / ``// NOLINTEND`` markers, placed inside the guard.

    Args:
        text: Source of the C/C++ file.
        filename: Path of the file; used to derive the guard macro and to
            decide whether the NOLINT markers are added.

    Returns:
        The rewritten source joined with ``"\\n"`` (no trailing newline).
    """

    def is_pragma_once(line: str) -> bool:
        return line.strip() == "#pragma once"

    lines = text.splitlines()
    should_nolint = filename.endswith((".hpp", "common.h"))
    should_add_include_guard = any(is_pragma_once(line) for line in lines)

    # Drop the existing "#pragma once" lines.
    filtered_lines = [line for line in lines if not is_pragma_once(line)]

    # Build the guard macro. Anything that is not an ASCII letter or digit
    # ("." "/" "\\" but also "-", "+", spaces, ...) becomes "_" so the macro
    # is always a valid preprocessor identifier.
    guard_name = "".join(
        char if (char.isascii() and char.isalnum()) else "_"
        for char in filename.upper()
    )
    if guard_name[:1].isdigit():
        # An identifier cannot start with a digit (e.g. "3rdparty/x.h").
        guard_name = f"GUARD_{guard_name}"
    guard_name = f"{guard_name}_"

    begin_guard_lines = []
    end_guard_lines = []
    if should_add_include_guard:
        begin_guard_lines.append(f"#ifndef {guard_name}")
        begin_guard_lines.append(f"#define {guard_name}")
    if should_nolint:
        begin_guard_lines.append("// NOLINTBEGIN")
        end_guard_lines.append("// NOLINTEND")
    if should_add_include_guard:
        end_guard_lines.append(f"#endif  // {guard_name}")

    return "\n".join(begin_guard_lines + filtered_lines + end_guard_lines)


def remove_single_line_comments(text: str) -> str:
    """Drop lines that consist solely of an empty ``//`` comment marker.

    A line is removed when, after stripping whitespace, it contains two or
    more characters and all of them are ``/`` (``//``, ``////////`` ...).
    Comments that carry text and code with trailing comments are kept.

    A lone ``/`` is kept on purpose: it is not a comment but can be a
    division operator wrapped onto its own line.

    Returns:
        The remaining lines joined with ``"\\n"`` (no trailing newline).
    """

    def is_empty_comment(line: str) -> bool:
        stripped = line.strip()
        return len(stripped) >= 2 and set(stripped) == {"/"}

    return "\n".join(
        line for line in text.splitlines() if not is_empty_comment(line)
    )


def is_text_file(file: Path):
    """Guess whether ``file`` is text by decoding its beginning as UTF-8.

    Only the start of the file is read (``read(1024)`` in text mode), so this
    is a heuristic and not a full validation.

    Returns:
        True when the sample decodes as UTF-8 (probably a text file), False
        when decoding fails (probably a binary file).
    """
    try:
        with open(file, mode="r", encoding="utf-8") as handle:
            handle.read(1024)
    except UnicodeDecodeError:
        # Not decodable as UTF-8: probably a binary file.
        return False
    else:
        # Decodable as UTF-8: probably a text file.
        return True


def process_file(
    input_dir: Path,
    output_dir: Path,
    file: str,
    clang_format_path: Path,
) -> None:
    """Export one file from ``input_dir`` into ``output_dir``.

    Files that are not UTF-8 text are copied unchanged. Text files go through
    these steps, chosen by file-name suffix:

    1. C/C++ sources and headers: empty ``//`` lines are removed and
       ``#pragma once`` is turned into an include guard.
    2. A license header is added: none for ``.md``/``.json``, the C-style
       block for C-like languages, the ``#`` style for everything else.
       Python files then get an encoding header and ``common_lib.sh`` gets
       a shebang.
    3. C/C++ and TypeScript/ETS files are formatted with clang-format.
    4. Python files have ``: ...`` bodies moved to their own line and
       ``assert`` statements replaced.

    Args:
        input_dir: Root of the source tree.
        output_dir: Root of the export tree; parent directories are created.
        file: Path of the file relative to both roots.
        clang_format_path: Style file handed to clang-format.
    """
    input_path = input_dir / file
    output_path = output_dir / file

    print(f"Processing {input_path} -> {output_path}")

    output_path.parent.mkdir(parents=True, exist_ok=True)

    if not is_text_file(input_path):
        shutil.copy2(input_path, output_path)
        return

    try:
        with open(input_path, "r", encoding="utf-8") as source:
            text = source.read()
    except UnicodeDecodeError:
        # is_text_file() only samples the beginning of the file; bytes that
        # are not valid UTF-8 may still appear later. Treat such a file as
        # binary instead of aborting the whole export.
        shutil.copy2(input_path, output_path)
        return

    if file.endswith((".c", ".cpp", ".h", ".hpp")):
        # Remove empty single-line comments, then convert #pragma once into
        # an include guard.
        text = remove_single_line_comments(text)
        text = convert_to_include_guard(text, file)

    if file.endswith((".md", ".json")):
        pass
    elif file.endswith((".taihe", ".ohidl", ".cpp", ".c", ".hpp", ".h", ".ets", ".ts", ".g4")):
        text = add_c_style_copyright(text)
    else:
        text = add_sh_style_copyright(text)
        if file.endswith(".py"):
            text = add_python_encoding_header(text)
        elif file.endswith("common_lib.sh"):
            text = add_shebang(text)

    # Suffix -> pseudo file name telling clang-format which language to use.
    # The suffixes are mutually exclusive, so at most one entry matches.
    assume_filenames = {
        ".c": "text.c",
        ".cpp": "text.cpp",
        ".h": "text.h",
        ".hpp": "text.hpp",
        ".ts": "text.js",
        ".ets": "text.js",
    }
    for suffix, assume_filename in assume_filenames.items():
        if file.endswith(suffix):
            text = format_code(text, clang_format_path, assume_filename)
            break

    if file.endswith(".py"):
        text = py_move_ellipsis(text)
        text = py_remove_asserts(text)

    with open(output_path, "w", encoding="utf-8") as target:
        target.write(text)


def git_get_valid_files():
    """List every file tracked at ``HEAD``, relative to the repository root.

    ``git ls-tree`` is run with ``-z`` so that paths are NUL-terminated and
    emitted verbatim; without it git wraps unusual paths (non-ASCII
    characters, quotes, control characters) in quotes with escapes, and those
    strings do not name real files.

    Returns:
        A list of path strings; empty when the tree holds no files.

    Raises:
        subprocess.CalledProcessError: git exited with a non-zero status.
    """
    result = subprocess.run(
        ["git", "ls-tree", "--full-tree", "-r", "--name-only", "-z", "HEAD"],
        capture_output=True,
        text=True,
        check=True,
    )
    # The output ends with a NUL, so drop the empty trailing field.
    return [path for path in result.stdout.split("\0") if path]


def git_get_repo_root():
    """Return the repository's top-level directory as reported by git.

    Runs ``git rev-parse --show-toplevel`` and strips surrounding whitespace
    from its output. A non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    command = ["git", "rev-parse", "--show-toplevel"]
    toplevel = subprocess.check_output(
        command,
        stderr=subprocess.PIPE,
        text=True,
    )
    return toplevel.strip()


def main() -> None:
    """Export all tracked, non-ignored files into ``<repo>/hw_exported``.

    The output directory is wiped and recreated, the clang-format style file
    (license header plus style text) is written into it, and each file from
    the git listing that is not ignored is passed to ``process_file``.
    """
    tracked_files = git_get_valid_files()
    repo_root = Path(git_get_repo_root())
    export_root = repo_root / "hw_exported"

    # Start from a clean output tree.
    shutil.rmtree(export_root, ignore_errors=True)
    export_root.mkdir(parents=True)

    style_path = export_root / ".clang-format"
    style_path.write_text(SH_COPYRIGHT + CLANG_FORMAT_STYLE)

    wanted = (name for name in tracked_files if not check_should_ignore(name))
    for name in wanted:
        process_file(repo_root, export_root, name, style_path)


if __name__ == "__main__":
    main()
